from typing import List, Optional

import torch
import torch.nn as nn
from mmcls.models import BACKBONES
from mmcls.models.backbones import VisionTransformer
from mmcls.models.utils import resize_pos_embed


@BACKBONES.register_module()
class HeadVIT(VisionTransformer):
    """Frozen ViT backbone that returns only class-token features.

    Every parameter created by :class:`VisionTransformer` is frozen at
    construction time. ``forward`` mirrors the mmcls ``VisionTransformer``
    forward pass, but for each index in ``out_indices`` it emits only the
    class-token embedding ``x[:, 0]`` instead of the patch-token feature map.

    Args:
        prompt_length (int): Number of prompt tokens. Stored as
            ``self.prompt_length``; not consumed by ``forward`` in this
            class. Defaults to 1.
        prompt_layers (List[int], optional): Layer indices intended to receive
            prompts. Stored as ``self.prompt_layers``; not consumed by
            ``forward`` in this class. Defaults to None.
        *args: Forwarded to :class:`VisionTransformer`. Because
            ``prompt_length`` and ``prompt_layers`` come first in the
            signature, positional arguments bind to them before reaching the
            parent; pass the parent's arguments by keyword (as mmcls configs
            do).
        **kwargs: Forwarded to :class:`VisionTransformer`.

    Raises:
        ValueError: If the parent is configured with ``with_cls_token=False``.
            ``forward`` reads token 0 as the class token, which would then
            silently be a patch token instead.
    """

    def __init__(self,
                 prompt_length: int = 1,
                 prompt_layers: Optional[List[int]] = None,
                 *args,
                 **kwargs):
        super().__init__(*args, **kwargs)

        # forward() returns x[:, 0] as the class-token feature; without a
        # class token in the sequence that would be a patch token.
        if not self.with_cls_token:
            raise ValueError(
                'HeadVIT returns the class token, so it requires '
                '`with_cls_token=True`.')

        # Kept on the instance so configs/subclasses can read them; nothing
        # in this class consumes them yet.
        self.prompt_length = prompt_length
        self.prompt_layers = prompt_layers

        # Freeze every parameter that exists at this point, i.e. those
        # created by VisionTransformer.__init__. Parameters registered later
        # (e.g. by a subclass) keep requires_grad=True.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Run the ViT encoder and collect class-token features.

        Follows the mmcls ``VisionTransformer.forward`` implementation up to
        the output stage.

        Args:
            x (torch.Tensor): Input batch for ``self.patch_embed``; presumably
                images of shape ``(B, C, H, W)`` — TODO confirm.

        Returns:
            tuple[torch.Tensor]: One entry per layer index in
            ``self.out_indices`` (in layer order). Each entry is ``x[:, 0]``
            of that layer's ``(B, N, C)`` token sequence, i.e. the class-token
            embedding with shape ``(B, C)``.
        """
        B = x.shape[0]
        x, patch_resolution = self.patch_embed(x)

        # Prepend the learnable class token (stole cls_tokens impl from
        # Phil Wang, thanks).
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        # Resize the position embedding to the actual patch grid, which may
        # differ from the one it was built for.
        x = x + resize_pos_embed(
            self.pos_embed,
            self.patch_resolution,
            patch_resolution,
            mode=self.interpolate_mode,
            num_extra_tokens=self.num_extra_tokens)
        x = self.drop_after_pos(x)
        # No class-token stripping here: __init__ guarantees
        # with_cls_token=True, so token 0 stays the class token.

        last_layer = len(self.layers) - 1
        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            # The final norm is applied only after the last block, so outputs
            # taken from earlier layers are un-normalized.
            if i == last_layer and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                outs.append(x[:, 0])

        return tuple(outs)
